import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adam
import safety_gym
import gym
import time
import  core
from utils.logx import EpochLogger
from utils.mpi_pytorch import setup_pytorch_for_mpi, sync_params, mpi_avg_grads
from utils.mpi_tools import mpi_fork, mpi_avg, proc_id, mpi_statistics_scalar, num_procs
from torch.nn.functional import softplus
# Import-time side effect: turns on autograd anomaly detection for the whole
# process.
# NOTE(review): the PyTorch docs describe this as a debugging aid that slows
# execution -- consider disabling it for long training runs.
torch.autograd.set_detect_anomaly(True)
import torch.optim as optim
import torch.nn.functional as F
import ignite.handlers.param_scheduler as IGN
import pandas as pd
import matplotlib.pyplot as plt
from torch.nn.utils import clip_grad_norm_
import os
# Import-time side effect: changes the process working directory, so every
# relative path opened after this point resolves against this directory.
# NOTE(review): the path is machine-specific; importing this module on a host
# without it raises at import time.
os.chdir('/home/user/safety-starter-agents/safe_rl/PPO_Lagrangian_PyTorch/data/PPO-POINT-TRAIN')
import gc
from torch.utils.data import DataLoader, TensorDataset

class Safety_NN(nn.Module):
    """Fully connected classifier with two hidden ReLU layers of width 128.

    The network maps an input vector of size ``n_state`` to ``n_class`` raw
    (un-normalised) scores; no softmax is applied in ``forward``.

    The sub-modules are named ``layer1``, ``layer2`` and ``layer3``; those
    names are the prefixes of the keys in ``state_dict()``.
    """

    def __init__(self, n_state, n_class):
        """Create the three linear layers.

        Args:
            n_state: Number of input features.
            n_class: Number of output scores.
        """
        super().__init__()
        hidden_width = 128
        # Construction order is kept (1 -> 2 -> 3) so that seeded parameter
        # initialisation draws random numbers in the same sequence.
        self.layer1 = nn.Linear(n_state, hidden_width)
        self.layer2 = nn.Linear(hidden_width, hidden_width)
        self.layer3 = nn.Linear(hidden_width, n_class)

    def forward(self, x):
        """Return the raw class scores for ``x``."""
        hidden = x
        for hidden_layer in (self.layer1, self.layer2):
            hidden = F.relu(hidden_layer(hidden))
        logits = self.layer3(hidden)
        return logits

def ppo(env_fn,
        seed=0,
        epochs=50,
        n_NN=1,
        n_class=2,
        n_action=2,
        file_intest_number_array=(5000000, 5000000),
        storage_intest_number_array=(5000000, 5000000),
        ce_section=4,  # 1 for nolag 1 and 4 for lag1.5
        ce_checkpoint=1,
        agent_checkpoint_mode="error"
        ):
    """Pre-train the safety classifier on stored "internal test" samples.

    Despite its name, this function runs no policy optimisation. It builds a
    ``Safety_NN``, loads labelled rows from a CSV file and minimises the
    cross-entropy loss with full-batch gradient steps. Every 1000 epochs it
    writes a scatter plot of the network outputs and a checkpoint.

    Args:
        env_fn: Zero-argument callable returning the environment. The
            environment is only used to read the observation dimension.
        seed: Base random seed; ``10000 * proc_id()`` is added to it.
        epochs: Number of full-batch training steps. Also used as the length
            of the linear learning-rate decay.
        n_NN: Number of classifiers. Only ``1`` is supported.
        n_class: Number of output classes. Labels are 0 and 1, so the loss
            needs at least two classes.
        n_action: Action dimension appended to the observation dimension to
            form the network input size. Only ``2`` is supported.
        file_intest_number_array: Only element ``[0]`` is read: the row index
            in the CSV at which the label-1 ("unsafe") samples start.
        storage_intest_number_array: ``[n_safe, n_unsafe]``, the number of
            label-0 rows taken from the top of the CSV and the number of
            label-1 rows taken from the offset above.
        ce_section: Suffix of the output root ``/home/user/POF_data_<n>/``.
        ce_checkpoint: Checkpoint series id used in checkpoint file names.
        agent_checkpoint_mode: Directory containing ``intest/intest_obs.csv``.

    Raises:
        AssertionError: If ``n_NN != 1`` or ``n_action != 2``.
        IndexError: If the CSV does not contain the requested rows.
    """
    # Fail fast: check the supported configuration before any expensive work.
    assert n_action == 2, "only n_action == 2 is supported"
    assert n_NN == 1, "only a single safety network (n_NN == 1) is supported"

    # Path setting. Create the output directories up front so that a missing
    # directory is reported now instead of after the first 1000 epochs.
    pof_data_path = "/home/user/POF_data_" + str(ce_section) + "/"
    plot_dir = os.path.join(pof_data_path, "data_format")
    checkpoint_dir = os.path.join(pof_data_path, "checkpoints")
    os.makedirs(plot_dir, exist_ok=True)
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Random seed
    seed += 10000 * proc_id()
    torch.manual_seed(seed)
    np.random.seed(seed)

    # Instantiate environment (only needed for the observation dimension)
    env = env_fn()
    obs_dim = env.observation_space.shape
    n_observation = obs_dim[0]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if device.type == 'cuda':
        torch.cuda.manual_seed(seed)

    # SNN setting
    SNN_list = [
        torch.nn.DataParallel(Safety_NN(n_observation + n_action, n_class).to(device))
        for _ in range(n_NN)
    ]
    pt_main_local_epoch = epochs
    pt_loss = nn.CrossEntropyLoss()
    optimizer_list = [
        optim.AdamW(snn.parameters(), lr=1e-4, weight_decay=1e-5, amsgrad=True)
        for snn in SNN_list
    ]
    scheduler_list = [
        optim.lr_scheduler.LinearLR(opt, start_factor=1, end_factor=1e-3,
                                    total_iters=pt_main_local_epoch)
        for opt in optimizer_list
    ]

    # IT setting
    n_safe = storage_intest_number_array[0]
    n_unsafe = storage_intest_number_array[1]
    unsafe_offset = file_intest_number_array[0]
    storage_intest_number = n_safe + n_unsafe

    # ETC setting
    backup_epoch = 0
    context = [n_safe, n_unsafe]
    intest_dataset_path = agent_checkpoint_mode + "/intest/intest_obs.csv"
    obss_np = pd.read_csv(intest_dataset_path, header=None).to_numpy()

    # Slicing silently truncates, so check the bounds explicitly to keep the
    # out-of-range failure that row-by-row indexing would have produced.
    n_rows = obss_np.shape[0]
    if (min(n_safe, n_unsafe, unsafe_offset) < 0
            or n_safe > n_rows
            or unsafe_offset + n_unsafe > n_rows):
        raise IndexError(
            f"{intest_dataset_path} has {n_rows} rows; cannot take rows "
            f"[0, {n_safe}) and [{unsafe_offset}, {unsafe_offset + n_unsafe})"
        )

    # Label-0 rows come from the top of the file, label-1 rows start at
    # unsafe_offset. Two slices replace millions of Python-level appends.
    intest_states = np.concatenate(
        (obss_np[:n_safe], obss_np[unsafe_offset:unsafe_offset + n_unsafe]),
        axis=0,
    )
    del obss_np
    gc.collect()
    print((n_NN,) + intest_states.shape)

    print("    START!")
    n_color_array = np.array(["blue", "red"])
    # One colour per sample (blue = label 0, red = label 1), built once.
    point_colors = np.repeat(n_color_array, [n_safe, n_unsafe])
    sample_index = np.arange(storage_intest_number)

    # Keep the batch dimension even for a single sample; squeezing it would
    # make the loss see an unbatched input with a batched target.
    x = torch.as_tensor(intest_states, dtype=torch.float32).to(device, non_blocking=True)
    y = torch.cat(
        (torch.zeros(n_safe, dtype=torch.long), torch.ones(n_unsafe, dtype=torch.long))
    ).to(device, non_blocking=True)
    del intest_states

    snn = SNN_list[0]
    optimizer = optimizer_list[0]
    scheduler = scheduler_list[0]
    snn.train()

    for epoch in range(backup_epoch, epochs):
        b = time.time()
        optimizer.zero_grad()
        pt_output_tmp = snn(x)
        pt_loss_tmp = pt_loss(pt_output_tmp, y)
        pt_loss_tmp.backward()
        optimizer.step()
        scheduler.step()
        h = time.time()
        print(h - b)
        if epoch % 1000 == 999:
            snn.eval()
            # Inference only: skip building an autograd graph over the
            # whole dataset.
            with torch.no_grad():
                intest_output = snn(x).cpu().numpy()
            figone, axsone = plt.subplots(n_class, figsize=(15, 9))
            for j in range(n_class):
                axsone[j].scatter(sample_index, intest_output[0:storage_intest_number, j],
                                  color=point_colors, s=1.0)
                axsone[j].set_title(str(j+1)+'_class')
            figone.suptitle('PRETRAINING STORATE IT OUTPUT PLOT', fontsize=16)
            plt.savefig(os.path.join(plot_dir, 'PT_STORAGE_IT_OUTPUT_PLOT_'+str(epoch+1)+'.png'))
            # Closing is enough; a following plt.clf() would only create a
            # fresh empty figure.
            plt.close('all')
            del intest_output
            # Save model. The file name uses the 0-based epoch while the plot
            # name uses epoch + 1; both are kept as they were.
            torch.save(
                {
                    # main info section
                    "epoch": epoch,
                    "context": context,
                    # SNN section
                    "SNN_0": snn.state_dict(),
                    "Soptim_0": optimizer.state_dict(),
                    "Ssch_0": scheduler.state_dict(),
                    # internal test case section
                    "storage_intest_number": storage_intest_number,
                },
                os.path.join(checkpoint_dir, f"pof-checkpoint-{ce_checkpoint}_epoch-{epoch}.pt"),
                pickle_protocol=4
            )
            print("    SAVE CHECKPOINT --- pof-checkpoint-%d_epoch-%d.pt" %(ce_checkpoint, epoch))
            snn.train()
    print("pretrainig end")


if __name__ == '__main__':
    # Command-line entry point: parse the options and start pre-training.
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--env', type=str, default='Safexp-PointGoal1-v0')
    parser.add_argument('--hid', type=int, default=256)
    parser.add_argument('--l', type=int, default=2)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--seed', '-s', type=int, default=0)
    parser.add_argument('--cpu', type=int, default=1)
    parser.add_argument('--steps', type=int, default=4000)
    # The default is the value that used to be hard-coded below, so a run
    # without --epochs behaves as before; an explicit --epochs is now
    # honoured instead of being silently overridden.
    parser.add_argument('--epochs', type=int, default=30000)
    parser.add_argument('--exp_name', type=str, default='ppo_point_train_')
    parser.add_argument('--mode', type=str, default='error')
    parser.add_argument('--checkpoint', type=str, default='-1')
    args = parser.parse_args()

    from utils.run_utils import setup_logger_kwargs
    # Append the last "_"-separated token of the mode to the experiment name.
    if args.mode != "error":
        args.exp_name += args.mode.split("_")[-1]
    # NOTE(review): logger_kwargs is not passed on to ppo(); the call is kept
    # in case setup_logger_kwargs has side effects -- confirm and remove.
    logger_kwargs = setup_logger_kwargs(args.exp_name, args.seed)
    ppo(lambda : gym.make(args.env), seed=args.seed, epochs=args.epochs, agent_checkpoint_mode=args.mode)
